{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/facebookresearch/esm/blob/master/examples/variant_prediction.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Variant prediction with ESM\n",
    "\n",
    "This tutorial demonstrates how to train a simple variant predictor, i.e. we predict the biological activity of mutations of a protein, using fixed embeddings from ESM. You can adopt a similar protocol to train a model for any downstream task, even with limited data.\n",
    "\n",
    "We will use a simple classifier in sklearn (or \"head\" on top of the transformer features) to predict the mutation effect from precomputed ESM embeddings. The embeddings for your dataset can be dumped once using a GPU. Then, the rest of your analysis can be done on CPU. \n",
    "\n",
    "### Background\n",
    "\n",
    "In this particular example, we will train a model to predict the activity of ß-lactamase variants.\n",
    "\n",
    "We provide the training in `examples/P62593.fasta`, a FASTA file where each entry contains:\n",
    "- the mutated ß-lactamase sequence, where a single residue is mutated (swapped with another amino acid)\n",
    "- the target value in the last field of the header, describing the scaled effect of the mutation\n",
    "\n",
    "The [data originally comes](https://github.com/FowlerLab/Envision2017/blob/master/data/dmsTraining_2017-02-20.csv) from a deep mutational scan and was released with the Envision paper (Gray, et al. 2018)\n",
    "\n",
    "### Goals\n",
    "- Obtain an embedding (fixed-dimensional vector representation) for each mutated sequence.\n",
    "- Train a regression model in sklearn that can predict the \"effect\" score given the embedding.\n",
    "\n",
    "\n",
    "### Prerequisites\n",
    "- You will need the following modules : tqdm, matplotlib, numpy, pandas, seaborn, scipy, scikit-learn\n",
    "- You have obtained sequence embeddings for ß-lactamase as described in the README, either by:\n",
    "    - running `python extract.py esm1_t34_670M_UR50S examples/P62593.fasta examples/P62593_reprs/ --repr_layers 34 --include mean`  OR \n",
    "    - for your convenience we precomputed the embeddings and you can download them from [here](https://dl.fbaipublicfiles.com/fair-esm/examples/P62593_reprs.tar.gz) - see below to download this right here from in this notebook\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Table of Contents\n",
    "1. [Prelims](#prelims)\n",
    "1. [Loading Embeddings](#load_embeddings)\n",
    "1. [Visualizing Embeddings](#viz_embeddings)\n",
    "1. [Initializing / Running Grid Search](#grid_search)\n",
    "1. [Browse Grid Search Results](#browse)\n",
    "1. [Evaluating Results](#eval)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='prelims'></a>\n",
    "## Prelims"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you are using colab, then uncomment and run the cell below.\n",
    "It will pip install the `esm` code, fetch the fasta file and the pre-computed embeddings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install git+https://github.com/facebookresearch/esm.git\n",
    "# !curl -O https://dl.fbaipublicfiles.com/fair-esm/examples/P62593_reprs.tar.gz\n",
    "# !tar -xzf P62593_reprs.tar.gz\n",
    "# !curl -O https://dl.fbaipublicfiles.com/fair-esm/examples/P62593.fasta\n",
    "# !pwd\n",
    "# !ls"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an alternative to pip installing, you can just put the repo on your python path:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import sys\n",
    "# PATH_TO_REPO = \"../\"\n",
    "# sys.path.append(PATH_TO_REPO)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from collections import Counter\n",
    "from tqdm import tqdm\n",
    "\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "\n",
    "import esm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy\n",
    "from sklearn.model_selection import GridSearchCV, train_test_split\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor\n",
    "from sklearn.svm import SVC, SVR\n",
    "from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor\n",
    "from sklearn.naive_bayes import GaussianNB\n",
    "from sklearn.linear_model import LogisticRegression, SGDRegressor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Add the path to your embeddings here:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "FASTA_PATH = \"./P62593.fasta\" # Path to P62593.fasta\n",
    "EMB_PATH = \"./P62593_reprs/\" # Path to directory of embeddings for P62593.fasta\n",
    "EMB_LAYER = 34"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='load_embeddings'></a>\n",
    "## Load embeddings (Xs) and target effects (ys)\n",
    "Our FASTA file is formatted as such:\n",
    "```\n",
    ">{index}|{mutation_id}|{effect}\n",
    "{seq}\n",
    "```\n",
    "We will be extracting the effect from each entry.\n",
    "\n",
    "Our embeddings are stored with the file name from fasta header: `{index}|{mutation_id}|{effect}.pt`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ys = []\n",
    "Xs = []\n",
    "for header, _seq in esm.data.read_fasta(FASTA_PATH):\n",
    "    scaled_effect = header.split('|')[-1]\n",
    "    ys.append(float(scaled_effect))\n",
    "    fn = f'{EMB_PATH}/{header[1:]}.pt'\n",
    "    embs = torch.load(fn)\n",
    "    Xs.append(embs['mean_representations'][EMB_LAYER])\n",
    "Xs = torch.stack(Xs, dim=0).numpy()\n",
    "print(len(ys))\n",
    "print(Xs.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train / Test Split\n",
    "\n",
    "Here we choose to follow the Envision paper, using 80% of the data for training, but we actually found that pre-trained ESM embeddings require fewer downstream training examples to reach the same level of performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_size = 0.8\n",
    "Xs_train, Xs_test, ys_train, ys_test = train_test_split(Xs, ys, train_size=train_size, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Xs_train.shape, Xs_test.shape, len(ys_train), len(ys_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### PCA\n",
    "\n",
    "Principal Component Analysis is a popular technique for dimensionality reduction. Given `n_features` (1280 in our case), PCA computes a new set of `X` that \"best explain the data.\" We've found that this enables downstream models to be trained faster with minimal loss in performance.  \n",
    "\n",
    "Here, we set `X` to 60, but feel free to change it!\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pca = PCA(60)\n",
    "Xs_train_pca = pca.fit_transform(Xs_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='viz_embeddings'></a>\n",
    "## Visualize Embeddings\n",
    "\n",
    "Here, we plot the first two principal components on the x- and y- axes. Each point is then colored by its scaled effect (what we want to predict).\n",
    "\n",
    "Visually, we can see a separation based on color/effect, suggesting that our representations are useful for this task, without any task-specific training!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig_dims = (7, 6)\n",
    "fig, ax = plt.subplots(figsize=fig_dims)\n",
    "sc = ax.scatter(Xs_train_pca[:,0], Xs_train_pca[:,1], c=ys_train, marker='.')\n",
    "ax.set_xlabel('PCA first principal component')\n",
    "ax.set_ylabel('PCA second principal component')\n",
    "plt.colorbar(sc, label='Variant Effect')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='grid_search'></a>\n",
    "\n",
    "## Initialize / Run GridSearch\n",
    "\n",
    "We will run grid search for three different regression models:\n",
    "1. [K-nearest-neighbors](https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsRegressor.html)\n",
    "2. [SVM](https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVR.html?highlight=svr#sklearn.svm.SVR)\n",
    "3. [Random Forest Regressor](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestRegressor.html?highlight=randomforestregressor#sklearn.ensemble.RandomForestRegressor)\n",
    "\n",
    "Here, we will be using the PCA-projected features because we observe it does just as well as `Xs` while allowing for faster training. You can easily swap it out for `Xs`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialize grids for different regression techniques"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "knn_grid = {\n",
    "    'n_neighbors': [5, 10],\n",
    "    'weights': ['uniform', 'distance'],\n",
    "    'algorithm': ['ball_tree', 'kd_tree', 'brute'],\n",
    "    'leaf_size' : [15, 30],\n",
    "    'p' : [1, 2],\n",
    "}\n",
    "\n",
    "svm_grid = {\n",
    "    'C' : [0.1, 1.0, 10.0],\n",
    "    'kernel' :['linear', 'poly', 'rbf', 'sigmoid'],\n",
    "    'degree' : [3],\n",
    "    'gamma': ['scale'],\n",
    "}\n",
    "\n",
    "rfr_grid = {\n",
    "    'n_estimators' : [20],\n",
    "    'criterion' : ['mse', 'mae'],\n",
    "    'max_features': ['sqrt', 'log2'],\n",
    "    'min_samples_split' : [5, 10],\n",
    "    'min_samples_leaf': [1, 4]\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cls_list = [KNeighborsRegressor, SVR, RandomForestRegressor]\n",
    "param_grid_list = [knn_grid, svm_grid, rfr_grid]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run Grid Search \n",
    "\n",
    "(will take a few minutes on a single core)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "result_list = []\n",
    "grid_list = []\n",
    "for cls_name, param_grid in zip(cls_list, param_grid_list):\n",
    "    print(cls_name)\n",
    "    grid = GridSearchCV(\n",
    "        estimator = cls_name(), \n",
    "        param_grid = param_grid,\n",
    "        scoring = 'r2',\n",
    "        verbose = 1,\n",
    "        n_jobs = -1 # use all available cores\n",
    "    )\n",
    "    grid.fit(Xs_train_pca, ys_train)\n",
    "    result_list.append(pd.DataFrame.from_dict(grid.cv_results_))\n",
    "    grid_list.append(grid)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='browse'></a>\n",
    "## Browse the Sweep Results\n",
    "\n",
    "The following tables show the top 5 parameter settings, based on `mean_test_score`. Given our setup, this should really be thought of as `validation_score`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### K Nearest Neighbors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "result_list[0].sort_values('rank_test_score')[:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SVM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "result_list[1].sort_values('rank_test_score')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Random Forest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "result_list[2].sort_values('rank_test_score')[:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='eval'></a>\n",
    "## Evaluation\n",
    "\n",
    "Now that we have run grid search, each `grid` object contains a `best_estimator_`.\n",
    "\n",
    "We can use this to evaluate the correlation between our predictions and the true effect scores on the held-out validation set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Xs_test_pca = pca.transform(Xs_test)\n",
    "for grid in grid_list:\n",
    "    print(grid.best_estimator_)\n",
    "    print()\n",
    "    preds = grid.predict(Xs_test_pca)\n",
    "    print(f'{scipy.stats.spearmanr(ys_test, preds)}')\n",
    "    print('\\n', '-' * 80, '\\n')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The SVM performs the best on the `test` set, with a spearman rho of 0.80! \n",
    "\n",
    "This is in line with our grid-search results, where it also had the best `validation` performance.\n",
    "\n",
    "In conclusion, our downstream model was able to use fixed pre-trained ESM embedding representations and obtain a decent result.\n",
    "\n",
    "(For reference, we report correlation of 0.89 in Table 7 of [the paper](https://www.biorxiv.org/content/10.1101/622803v3), but this was achieved by fine-tuning the model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
